# %%
import matplotlib.pyplot as plt
import numpy as np
from scipy.sparse import load_npz
from torch_factor import matrix_factor
# %%
def _load_dense(path):
    """Load a scipy sparse ``.npz`` file and return it as a dense ``ndarray``.

    ``toarray()`` yields an ``ndarray`` directly. The previous form,
    ``np.array(m.todense())``, first built an ``np.matrix`` and then copied
    it, so two dense buffers existed at peak memory.
    """
    return load_npz(path).toarray()


# Densify both splits and print how many non-zero entries each has, as a
# quick sanity check. count_nonzero avoids the full boolean temporary that
# np.sum(X != 0) would allocate.
Xtrain = _load_dense("train.npz")
print(np.count_nonzero(Xtrain))

Xtest = _load_dense("test.npz")
print(np.count_nonzero(Xtest))


# %%
# Factorize the training matrix, passing the test matrix alongside it.
# NOTE(review): the meanings below are assumptions from the names only.
# Confirm them against torch_factor.matrix_factor.
U, V, record = matrix_factor(
    Xtrain,
    Xtest,
    k=50,  # presumably the latent rank of the factors
    l=0.01,  # presumably a regularization weight
    max_iter=200,  # presumably the iteration cap
)

# %%
